{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract Poses from Amass Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import sys, os\n",
    "\n",
    "# PyOpenGL selects its platform from this variable when it is first imported,\n",
    "# so set it before importing anything that may load an OpenGL backend.\n",
    "os.environ['PYOPENGL_PLATFORM'] = 'egl'\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "from human_body_prior.tools.omni_tools import copy2cpu as c2c"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Please remember to download the following subdatasets from the AMASS website: https://amass.is.tue.mpg.de/download.php. Note: only download the <u>SMPL+H G</u> data.\n",
    "* ACCAD (ACCAD)\n",
    "* HDM05 (MPI_HDM05)\n",
    "* TCDHands (TCD_handMocap)\n",
    "* SFU (SFU)\n",
    "* BMLmovi (BMLmovi)\n",
    "* CMU (CMU)\n",
    "* Mosh (MPI_mosh)\n",
    "* EKUT (EKUT)\n",
    "* KIT (KIT)\n",
    "* Eyes_Japan_Dataset (Eyes_Japan_Dataset)\n",
    "* BMLhandball (BMLhandball)\n",
    "* Transitions (Transitions_mocap)\n",
    "* PosePrior (MPI_Limits)\n",
    "* HumanEva (HumanEva)\n",
    "* SSM (SSM_synced)\n",
    "* DFaust (DFaust_67)\n",
    "* TotalCapture (TotalCapture)\n",
    "* BMLrub (BioMotionLab_NTroje)\n",
    "\n",
    "### Unzip all datasets. The name in parentheses is the folder name each dataset must have after unzipping; rename your folders to match if they differ."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Place all files under the directory **./amass_data/**. The directory structure should look like the following:  \n",
    "./amass_data/  \n",
    "./amass_data/ACCAD/  \n",
    "./amass_data/BioMotionLab_NTroje/  \n",
    "./amass_data/BMLhandball/  \n",
    "./amass_data/BMLmovi/   \n",
    "./amass_data/CMU/  \n",
    "./amass_data/DFaust_67/  \n",
    "./amass_data/EKUT/  \n",
    "./amass_data/Eyes_Japan_Dataset/  \n",
    "./amass_data/HumanEva/  \n",
    "./amass_data/KIT/  \n",
    "./amass_data/MPI_HDM05/  \n",
    "./amass_data/MPI_Limits/  \n",
    "./amass_data/MPI_mosh/  \n",
    "./amass_data/SFU/  \n",
    "./amass_data/SSM_synced/  \n",
    "./amass_data/TCD_handMocap/  \n",
    "./amass_data/TotalCapture/  \n",
    "./amass_data/Transitions_mocap/  \n",
    "\n",
    "**Please make sure the file paths are correct; otherwise the extraction will fail.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the device to run the body model on.\n",
    "# gpu_id must be a valid CUDA device index on this machine. A fixed index such as 2\n",
    "# fails on machines with fewer GPUs, so fall back to the CPU when it is out of range.\n",
    "gpu_id = 0\n",
    "use_cuda = torch.cuda.is_available() and gpu_id < torch.cuda.device_count()\n",
    "comp_device = torch.device('cuda:%d' % gpu_id if use_cuda else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from human_body_prior.body_model.body_model import BodyModel\n",
    "\n",
    "num_betas = 10  # number of body shape (beta) parameters\n",
    "num_dmpls = 8   # number of DMPL parameters\n",
    "\n",
    "# SMPL+H model file and matching DMPL file for each gender.\n",
    "male_bm_path, male_dmpl_path = './body_models/smplh/male/model.npz', './body_models/dmpls/male/model.npz'\n",
    "female_bm_path, female_dmpl_path = './body_models/smplh/female/model.npz', './body_models/dmpls/female/model.npz'\n",
    "\n",
    "def _load_body_model(bm_fname, dmpl_fname):\n",
    "    \"\"\"Build a BodyModel from the given model/DMPL files and move it to comp_device.\"\"\"\n",
    "    return BodyModel(bm_fname=bm_fname, num_betas=num_betas, num_dmpls=num_dmpls,\n",
    "                     dmpl_fname=dmpl_fname).to(comp_device)\n",
    "\n",
    "male_bm = _load_body_model(male_bm_path, male_dmpl_path)\n",
    "faces = c2c(male_bm.f)  # mesh faces of the male model, copied to the CPU via c2c\n",
    "female_bm = _load_body_model(female_bm_path, female_dmpl_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every AMASS sequence file, every folder to mirror under ./pose_data,\n",
    "# and the dataset names (the first folder level below ./amass_data).\n",
    "amass_root = './amass_data'\n",
    "paths = []\n",
    "folders = []\n",
    "dataset_names = []\n",
    "for root, dirs, files in os.walk(amass_root):\n",
    "    folders.append(root)\n",
    "    # Use os.path instead of splitting on '/' so this also works with Windows separators.\n",
    "    dataset_name = os.path.relpath(root, amass_root).split(os.sep)[0]\n",
    "    if dataset_name == os.curdir:\n",
    "        # Files directly under amass_root do not belong to any dataset.\n",
    "        continue\n",
    "    for name in files:\n",
    "        # Only .npz archives are sequence files for the extraction step; skip any other files.\n",
    "        if not name.endswith('.npz'):\n",
    "            continue\n",
    "        if dataset_name not in dataset_names:\n",
    "            dataset_names.append(dataset_name)\n",
    "        paths.append(os.path.join(root, name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_root = './pose_data'\n",
    "save_folders = [folder.replace('./amass_data', save_root) for folder in folders]\n",
    "for folder in save_folders:\n",
    "    os.makedirs(folder, exist_ok=True)\n",
    "\n",
    "# Group the source files by dataset, matching the first folder below ./amass_data exactly.\n",
    "# A plain substring test would also match a dataset name that happens to occur\n",
    "# elsewhere in another dataset's path.\n",
    "paths_by_dataset = {name: [] for name in dataset_names}\n",
    "for path in paths:\n",
    "    top_folder = os.path.relpath(path, './amass_data').split(os.sep)[0]\n",
    "    if top_folder in paths_by_dataset:\n",
    "        paths_by_dataset[top_folder].append(path)\n",
    "# Keep only non-empty groups so every group has at least one file.\n",
    "group_path = [paths_by_dataset[name] for name in dataset_names if paths_by_dataset[name]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps (x, y, z) to (x, z, y), i.e. swaps the y and z axes of every joint.\n",
    "trans_matrix = np.array([[1.0, 0.0, 0.0],\n",
    "                         [0.0, 0.0, 1.0],\n",
    "                         [0.0, 1.0, 0.0]])\n",
    "ex_fps = 20  # target frame rate of the extracted pose sequences\n",
    "\n",
    "def amass_to_pose(src_path, save_path):\n",
    "    \"\"\"Compute the joint positions of one AMASS sequence and save them with np.save.\n",
    "\n",
    "    The sequence is subsampled towards ex_fps, run through the gender-matched\n",
    "    SMPL+H body model on comp_device, and the joints (body.Jtr) are mapped\n",
    "    through trans_matrix before being written to save_path.\n",
    "\n",
    "    Returns the source mocap frame rate. Returns 0 without saving anything when\n",
    "    src_path is not an .npz archive or has no 'mocap_framerate' entry, and returns\n",
    "    the frame rate without saving when the archive has no 'trans' entry.\n",
    "    \"\"\"\n",
    "    if not src_path.endswith('.npz'):\n",
    "        return 0\n",
    "    # The context manager closes the archive's file handle once the arrays are read.\n",
    "    with np.load(src_path, allow_pickle=True) as bdata:\n",
    "        if 'mocap_framerate' not in bdata.files:\n",
    "            return 0\n",
    "        fps = bdata['mocap_framerate']\n",
    "        if 'trans' not in bdata.files:\n",
    "            return fps\n",
    "\n",
    "        # gender may be stored as str or bytes; decode so 'male' is recognised either way.\n",
    "        gender = bdata['gender'].item()\n",
    "        if isinstance(gender, bytes):\n",
    "            gender = gender.decode('utf-8')\n",
    "        bm = male_bm if gender == 'male' else female_bm\n",
    "\n",
    "        # Keep every down_sample-th frame. The step is truncated to an int (so the output\n",
    "        # rate is only approximately ex_fps) and clamped to at least 1 so that sources\n",
    "        # slower than ex_fps do not produce an invalid zero slice step.\n",
    "        down_sample = max(1, int(fps / ex_fps))\n",
    "        bdata_poses = bdata['poses'][::down_sample, ...]\n",
    "        bdata_trans = bdata['trans'][::down_sample, ...]\n",
    "        betas = bdata['betas'][:num_betas]\n",
    "\n",
    "    body_parms = {\n",
    "        'root_orient': torch.Tensor(bdata_poses[:, :3]).to(comp_device),\n",
    "        'pose_body': torch.Tensor(bdata_poses[:, 3:66]).to(comp_device),\n",
    "        'pose_hand': torch.Tensor(bdata_poses[:, 66:]).to(comp_device),\n",
    "        'trans': torch.Tensor(bdata_trans).to(comp_device),\n",
    "        # The same shape parameters are used for every kept frame.\n",
    "        'betas': torch.Tensor(np.repeat(betas[np.newaxis], repeats=len(bdata_trans), axis=0)).to(comp_device),\n",
    "    }\n",
    "\n",
    "    with torch.no_grad():\n",
    "        body = bm(**body_parms)\n",
    "    pose_seq_np = body.Jtr.detach().cpu().numpy()\n",
    "    np.save(save_path, np.dot(pose_seq_np, trans_matrix))\n",
    "    return fps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To process only some datasets, slice group_path here, e.g. group_path = group_path[:1].\n",
    "all_count = sum(len(dataset_paths) for dataset_paths in group_path)\n",
    "cur_count = 0  # number of files processed so far; advanced by the extraction loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Processing all datasets takes a few hours; the output below shows a run on a single dataset (SFU) as an example.\n",
    "\n",
    "To speed things up, you can run several copies of this notebook at the same time, each on a different slice of `group_path`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing: SFU: 100%|██████████| 44/44 [11:42<00:00, 15.97s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processed / All (fps 120): 44/44\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "# Iterate with a distinct name so the global `paths` list is not overwritten.\n",
    "for dataset_paths in group_path:\n",
    "    if not dataset_paths:\n",
    "        continue\n",
    "    # Name of the first folder below ./amass_data, independent of the path separator.\n",
    "    dataset_name = os.path.relpath(dataset_paths[0], './amass_data').split(os.sep)[0]\n",
    "    pbar = tqdm(dataset_paths)\n",
    "    pbar.set_description('Processing: %s'%dataset_name)\n",
    "    fps = 0\n",
    "    for path in pbar:\n",
    "        save_path = path.replace('./amass_data', './pose_data')\n",
    "        # Same relative path under ./pose_data, with the extension replaced by .npy.\n",
    "        save_path = os.path.splitext(save_path)[0] + '.npy'\n",
    "        fps = amass_to_pose(path, save_path)\n",
    "\n",
    "    cur_count += len(dataset_paths)\n",
    "    print('Processed / All (fps %d): %d/%d'% (fps, cur_count, all_count) )\n",
    "    # Brief pause, presumably so the progress bar finishes rendering before the next one starts.\n",
    "    time.sleep(0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above code will extract poses from **AMASS** dataset, and put them under directory **\"./pose_data\"**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The source data from **HumanAct12** is already included in **\"./pose_data\"** in this repository. You need to **unzip** it right in this folder."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segment, Mirror and Relocate Motions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import codecs as cs\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from os.path import join as pjoin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def swap_left_right(data):\n",
    "    \"\"\"Return a left/right mirrored copy of a joint sequence.\n",
    "\n",
    "    data must have shape (frames, joints, 3). The copy has its x coordinate\n",
    "    negated and each left body joint exchanged with its right counterpart. When\n",
    "    there are more than 24 joints, the 15 left-hand joints (22-36) are also\n",
    "    exchanged with the 15 right-hand joints (37-51). The input is not modified.\n",
    "    \"\"\"\n",
    "    assert len(data.shape) == 3 and data.shape[-1] == 3\n",
    "    data = data.copy()\n",
    "    data[..., 0] *= -1\n",
    "    right_chain = [2, 5, 8, 11, 14, 17, 19, 21]\n",
    "    left_chain = [1, 4, 7, 10, 13, 16, 18, 20]\n",
    "    # In SMPL+H both hands list their finger joints in the same (MANO) order, so\n",
    "    # left-hand joint 22 + k corresponds to right-hand joint 37 + k. Pairing the two\n",
    "    # hands in different orders would move, e.g., a thumb onto a ring finger.\n",
    "    left_hand_chain = list(range(22, 37))\n",
    "    right_hand_chain = list(range(37, 52))\n",
    "    tmp = data[:, right_chain]\n",
    "    data[:, right_chain] = data[:, left_chain]\n",
    "    data[:, left_chain] = tmp\n",
    "    if data.shape[1] > 24:\n",
    "        tmp = data[:, right_hand_chain]\n",
    "        data[:, right_hand_chain] = data[:, left_hand_chain]\n",
    "        data[:, left_hand_chain] = tmp\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs and outputs of the segmentation step.\n",
    "index_path = './index.csv'  # columns used below: source_path, start_frame, end_frame, new_name\n",
    "save_dir = './joints'\n",
    "fps = 20  # frame rate of the extracted pose sequences\n",
    "\n",
    "index_file = pd.read_csv(index_path)\n",
    "total_amount = len(index_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14616/14616 [00:30<00:00, 486.52it/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "os.makedirs(save_dir, exist_ok=True)  # np.save does not create missing directories\n",
    "\n",
    "# Seconds trimmed from the start of sequences of these AMASS datasets before\n",
    "# segmenting; start_frame/end_frame from the index are applied to the trimmed sequence.\n",
    "trim_seconds = {\n",
    "    'Eyes_Japan_Dataset': 3,\n",
    "    'MPI_HDM05': 3,\n",
    "    'TotalCapture': 1,\n",
    "    'MPI_Limits': 1,\n",
    "    'Transitions_mocap': 0.5,\n",
    "}\n",
    "\n",
    "for i in tqdm(range(total_amount)):\n",
    "    row = index_file.loc[i]\n",
    "    source_path = row['source_path']\n",
    "    new_name = row['new_name']\n",
    "    data = np.load(source_path)\n",
    "    if 'humanact12' not in source_path:\n",
    "        for dataset, seconds in trim_seconds.items():\n",
    "            if dataset in source_path:\n",
    "                data = data[int(seconds * fps):]\n",
    "        data = data[row['start_frame']:row['end_frame']]\n",
    "        # Flip x here; swap_left_right flips it back for the mirrored copy.\n",
    "        data[..., 0] *= -1\n",
    "\n",
    "    data_m = swap_left_right(data)\n",
    "    np.save(pjoin(save_dir, new_name), data)\n",
    "    np.save(pjoin(save_dir, 'M'+new_name), data_m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch_render]",
   "language": "python",
   "name": "conda-env-torch_render-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
